import torch
from tensorboardX import SummaryWriter
from unet import UNet
# Export the UNet computation graph so it can be inspected in TensorBoard.
# Event files are written under the 'tensorboard' directory.
writer = SummaryWriter('tensorboard')
try:
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet(n_channels=3, n_classes=1, bilinear=True).to(device)

    # All-ones example tensor used only to trace the model for the graph.
    # Presumably laid out as (batch, channels, height, width), with the 3
    # channels matching n_channels=3 -- TODO confirm against UNet.
    # Allocated directly on the target device to avoid a CPU -> device copy,
    # and named so that it does not shadow the builtin `input`.
    example_input = torch.ones(1, 3, 360, 640, device=device)

    writer.add_graph(model, example_input)
finally:
    # Close explicitly so pending events are written to disk even if tracing
    # fails, instead of relying on interpreter-exit cleanup.
    writer.close()
